{
 "cells": [
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-17T12:40:27.050053Z",
     "start_time": "2025-01-17T12:40:27.045699Z"
    }
   },
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "print(sys.version_info)\n",
    "for module in mpl, np, pd, sklearn, torch:\n",
    "    print(module.__name__, module.__version__)\n",
    "    \n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(device)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sys.version_info(major=3, minor=12, micro=3, releaselevel='final', serial=0)\n",
      "matplotlib 3.10.0\n",
      "numpy 1.26.4\n",
      "pandas 2.2.3\n",
      "sklearn 1.6.0\n",
      "torch 2.5.1+cpu\n",
      "cpu\n"
     ]
    }
   ],
   "execution_count": 10
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载数据"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-17T12:40:27.102186Z",
     "start_time": "2025-01-17T12:40:27.075059Z"
    }
   },
   "source": [
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor\n",
    "\n",
    "# fashion_mnist图像分类数据集\n",
    "train_ds = datasets.FashionMNIST(\n",
    "    root=\"data\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=ToTensor()\n",
    ")\n",
    "\n",
    "test_ds = datasets.FashionMNIST(\n",
    "    root=\"data\",\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=ToTensor()\n",
    ")\n",
    "\n",
    "# torchvision 数据集里没有提供训练集和验证集的划分\n",
    "# 当然也可以用 torch.utils.data.Dataset 实现人为划分\n",
    "# 从数据集到dataloader\n",
    "train_loader = torch.utils.data.DataLoader(train_ds, batch_size=16, shuffle=True)\n",
    "val_loader = torch.utils.data.DataLoader(test_ds, batch_size=16, shuffle=False)"
   ],
   "outputs": [],
   "execution_count": 11
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-17T12:40:27.105199Z",
     "start_time": "2025-01-17T12:40:27.102186Z"
    }
   },
   "source": [
    "from torchvision.transforms import Normalize\n",
    "\n",
    "# 遍历train_ds得到每张图片，计算每个通道的均值和方差\n",
    "def cal_mean_std(ds):\n",
    "    mean = 0.\n",
    "    std = 0.\n",
    "    for img, _ in ds:\n",
    "        mean += img.mean(dim=(1, 2))\n",
    "        std += img.std(dim=(1, 2))\n",
    "    mean /= len(ds)\n",
    "    std /= len(ds)\n",
    "    return mean, std\n",
    "\n",
    "\n",
    "# print(cal_mean_std(train_ds))\n",
    "# 0.2860， 0.3205\n",
    "transforms = nn.Sequential(\n",
    "    Normalize([0.2860], [0.3205])\n",
    ")\n"
   ],
   "outputs": [],
   "execution_count": 12
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义模型\n",
    "\n",
    "这里我们没有用`nn.Linear`的默认初始化，而是采用了xavier均匀分布去初始化全连接层的权重\n",
    "\n",
    "xavier初始化出自论文 《Understanding the difficulty of training deep feedforward neural networks》，适用于使用`tanh`和`sigmoid`激活函数的方法。当然，我们这里的模型采用的是`relu`激活函数，采用He初始化（何凯明初始化）会更加合适。感兴趣的同学可以自己动手修改并比对效果。\n",
    "\n",
    "|神经网络层数|初始化方式|early stop at epoch| val_loss | vla_acc|\n",
    "|-|-|-|-|-|\n",
    "|20|默认|\n",
    "|20|xaviier_uniform|\n",
    "|20|he_uniform|\n",
    "|...|\n",
    "\n",
    "He初始化出自论文 《Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-17T12:40:27.112616Z",
     "start_time": "2025-01-17T12:40:27.105199Z"
    }
   },
   "source": [
    "class NeuralNetwork(nn.Module):\n",
    "    def __init__(self, layers_num=2):\n",
    "        super().__init__()\n",
    "        self.transforms = transforms\n",
    "        self.flatten = nn.Flatten()\n",
    "        # 多加几层\n",
    "        self.linear_relu_stack = nn.Sequential(\n",
    "            nn.Linear(28 * 28, 100),  # in_features=784, out_features=300\n",
    "            nn.ReLU(),\n",
    "            nn.AlphaDropout(p=0.2) # 增加dropout，p=0.2表示以0.2的概率将某些神经元置0，防止过拟合\n",
    "        )\n",
    "        # 加19层\n",
    "        for i in range(1, layers_num):\n",
    "            self.linear_relu_stack.add_module(f\"Linear_{i}\", nn.Linear(100, 100))\n",
    "            self.linear_relu_stack.add_module(f\"relu\", nn.ReLU())\n",
    "            if i<3:\n",
    "                # self.linear_relu_stack.add_module(f\"dropout_{i}\", nn.AlphaDropout(p=0.2))\n",
    "                pass\n",
    "             # 增加dropout\n",
    "        # 输出层\n",
    "        self.linear_relu_stack.add_module(\"Output Layer\", nn.Linear(100, 10))\n",
    "        \n",
    "        # 初始化权重\n",
    "        self.init_weights()\n",
    "        \n",
    "    def init_weights(self):\n",
    "        \"\"\"使用 xavier 均匀分布来初始化全连接层的权重 W\"\"\"\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Linear):\n",
    "                nn.init.xavier_uniform_(m.weight)\n",
    "                nn.init.zeros_(m.bias)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x.shape [batch size, 1, 28, 28]\n",
    "        x = self.transforms(x)\n",
    "        x = self.flatten(x)  \n",
    "        # 展平后 x.shape [batch size, 28 * 28]\n",
    "        logits = self.linear_relu_stack(x)\n",
    "        # logits.shape [batch size, 10]\n",
    "        return logits\n",
    "\n",
    "print(f\"{'layer_name':^40}\\tparamerters num\")\n",
    "for idx, (key, value) in enumerate(NeuralNetwork(20).named_parameters()):\n",
    "    print(\"{:<40}\\t{:^10}\".format(key, np.prod(value.shape)))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "               layer_name               \tparamerters num\n",
      "linear_relu_stack.0.weight              \t  78400   \n",
      "linear_relu_stack.0.bias                \t   100    \n",
      "linear_relu_stack.Linear_1.weight       \t  10000   \n",
      "linear_relu_stack.Linear_1.bias         \t   100    \n",
      "linear_relu_stack.Linear_2.weight       \t  10000   \n",
      "linear_relu_stack.Linear_2.bias         \t   100    \n",
      "linear_relu_stack.Linear_3.weight       \t  10000   \n",
      "linear_relu_stack.Linear_3.bias         \t   100    \n",
      "linear_relu_stack.Linear_4.weight       \t  10000   \n",
      "linear_relu_stack.Linear_4.bias         \t   100    \n",
      "linear_relu_stack.Linear_5.weight       \t  10000   \n",
      "linear_relu_stack.Linear_5.bias         \t   100    \n",
      "linear_relu_stack.Linear_6.weight       \t  10000   \n",
      "linear_relu_stack.Linear_6.bias         \t   100    \n",
      "linear_relu_stack.Linear_7.weight       \t  10000   \n",
      "linear_relu_stack.Linear_7.bias         \t   100    \n",
      "linear_relu_stack.Linear_8.weight       \t  10000   \n",
      "linear_relu_stack.Linear_8.bias         \t   100    \n",
      "linear_relu_stack.Linear_9.weight       \t  10000   \n",
      "linear_relu_stack.Linear_9.bias         \t   100    \n",
      "linear_relu_stack.Linear_10.weight      \t  10000   \n",
      "linear_relu_stack.Linear_10.bias        \t   100    \n",
      "linear_relu_stack.Linear_11.weight      \t  10000   \n",
      "linear_relu_stack.Linear_11.bias        \t   100    \n",
      "linear_relu_stack.Linear_12.weight      \t  10000   \n",
      "linear_relu_stack.Linear_12.bias        \t   100    \n",
      "linear_relu_stack.Linear_13.weight      \t  10000   \n",
      "linear_relu_stack.Linear_13.bias        \t   100    \n",
      "linear_relu_stack.Linear_14.weight      \t  10000   \n",
      "linear_relu_stack.Linear_14.bias        \t   100    \n",
      "linear_relu_stack.Linear_15.weight      \t  10000   \n",
      "linear_relu_stack.Linear_15.bias        \t   100    \n",
      "linear_relu_stack.Linear_16.weight      \t  10000   \n",
      "linear_relu_stack.Linear_16.bias        \t   100    \n",
      "linear_relu_stack.Linear_17.weight      \t  10000   \n",
      "linear_relu_stack.Linear_17.bias        \t   100    \n",
      "linear_relu_stack.Linear_18.weight      \t  10000   \n",
      "linear_relu_stack.Linear_18.bias        \t   100    \n",
      "linear_relu_stack.Linear_19.weight      \t  10000   \n",
      "linear_relu_stack.Linear_19.bias        \t   100    \n",
      "linear_relu_stack.Output Layer.weight   \t   1000   \n",
      "linear_relu_stack.Output Layer.bias     \t    10    \n"
     ]
    }
   ],
   "execution_count": 13
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-17T12:40:27.116042Z",
     "start_time": "2025-01-17T12:40:27.112616Z"
    }
   },
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluating(model, dataloader, loss_fct):\n",
    "    loss_list = []\n",
    "    pred_list = []\n",
    "    label_list = []\n",
    "    for datas, labels in dataloader:\n",
    "        datas = datas.to(device)\n",
    "        labels = labels.to(device)\n",
    "        # 前向计算\n",
    "        logits = model(datas)\n",
    "        loss = loss_fct(logits, labels)         # 验证集损失\n",
    "        loss_list.append(loss.item())\n",
    "        \n",
    "        preds = logits.argmax(axis=-1)    # 验证集预测\n",
    "        pred_list.extend(preds.cpu().numpy().tolist())\n",
    "        label_list.extend(labels.cpu().numpy().tolist())\n",
    "        \n",
    "    acc = accuracy_score(label_list, pred_list)\n",
    "    return np.mean(loss_list), acc\n"
   ],
   "outputs": [],
   "execution_count": 14
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-17T12:40:27.120212Z",
     "start_time": "2025-01-17T12:40:27.116042Z"
    }
   },
   "source": [
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "class TensorBoardCallback:\n",
    "    def __init__(self, log_dir, flush_secs=10):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            log_dir (str): dir to write log.\n",
    "            flush_secs (int, optional): write to dsk each flush_secs seconds. Defaults to 10.\n",
    "        \"\"\"\n",
    "        self.writer = SummaryWriter(log_dir=log_dir, flush_secs=flush_secs)\n",
    "\n",
    "    def draw_model(self, model, input_shape):\n",
    "        self.writer.add_graph(model, input_to_model=torch.randn(input_shape))\n",
    "        \n",
    "    def add_loss_scalars(self, step, loss, val_loss):\n",
    "        self.writer.add_scalars(\n",
    "            main_tag=\"training/loss\", \n",
    "            tag_scalar_dict={\"loss\": loss, \"val_loss\": val_loss},\n",
    "            global_step=step,\n",
    "            )\n",
    "        \n",
    "    def add_acc_scalars(self, step, acc, val_acc):\n",
    "        self.writer.add_scalars(\n",
    "            main_tag=\"training/accuracy\",\n",
    "            tag_scalar_dict={\"accuracy\": acc, \"val_accuracy\": val_acc},\n",
    "            global_step=step,\n",
    "        )\n",
    "        \n",
    "    def add_lr_scalars(self, step, learning_rate):\n",
    "        self.writer.add_scalars(\n",
    "            main_tag=\"training/learning_rate\",\n",
    "            tag_scalar_dict={\"learning_rate\": learning_rate},\n",
    "            global_step=step,\n",
    "            \n",
    "        )\n",
    "    \n",
    "    def __call__(self, step, **kwargs):\n",
    "        # add loss\n",
    "        loss = kwargs.pop(\"loss\", None)\n",
    "        val_loss = kwargs.pop(\"val_loss\", None)\n",
    "        if loss is not None and val_loss is not None:\n",
    "            self.add_loss_scalars(step, loss, val_loss)\n",
    "        # add acc\n",
    "        acc = kwargs.pop(\"acc\", None)\n",
    "        val_acc = kwargs.pop(\"val_acc\", None)\n",
    "        if acc is not None and val_acc is not None:\n",
    "            self.add_acc_scalars(step, acc, val_acc)\n",
    "        # add lr\n",
    "        learning_rate = kwargs.pop(\"lr\", None)\n",
    "        if learning_rate is not None:\n",
    "            self.add_lr_scalars(step, learning_rate)\n"
   ],
   "outputs": [],
   "execution_count": 15
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-17T12:40:27.123665Z",
     "start_time": "2025-01-17T12:40:27.120212Z"
    }
   },
   "source": [
    "class SaveCheckpointsCallback:\n",
    "    def __init__(self, save_dir, save_step=5000, save_best_only=True):\n",
    "        \"\"\"\n",
    "        Save checkpoints each save_epoch epoch. \n",
    "        We save checkpoint by epoch in this implementation.\n",
    "        Usually, training scripts with pytorch evaluating model and save checkpoint by step.\n",
    "\n",
    "        Args:\n",
    "            save_dir (str): dir to save checkpoint\n",
    "            save_epoch (int, optional): the frequency to save checkpoint. Defaults to 1.\n",
    "            save_best_only (bool, optional): If True, only save the best model or save each model at every epoch.\n",
    "        \"\"\"\n",
    "        self.save_dir = save_dir\n",
    "        self.save_step = save_step\n",
    "        self.save_best_only = save_best_only\n",
    "        self.best_metrics = -1\n",
    "        \n",
    "        # mkdir\n",
    "        if not os.path.exists(self.save_dir):\n",
    "            os.mkdir(self.save_dir)\n",
    "        \n",
    "    def __call__(self, step, state_dict, metric=None):\n",
    "        if step % self.save_step > 0:\n",
    "            return\n",
    "        \n",
    "        if self.save_best_only:\n",
    "            assert metric is not None\n",
    "            if metric >= self.best_metrics:\n",
    "                # save checkpoints\n",
    "                torch.save(state_dict, os.path.join(self.save_dir, \"best.ckpt\"))\n",
    "                # update best metrics\n",
    "                self.best_metrics = metric\n",
    "        else:\n",
    "            torch.save(state_dict, os.path.join(self.save_dir, f\"{step}.ckpt\"))\n",
    "\n"
   ],
   "outputs": [],
   "execution_count": 16
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-17T12:40:27.127223Z",
     "start_time": "2025-01-17T12:40:27.123665Z"
    }
   },
   "source": [
    "class EarlyStopCallback:\n",
    "    def __init__(self, patience=5, min_delta=0.01):\n",
    "        \"\"\"\n",
    "\n",
    "        Args:\n",
    "            patience (int, optional): Number of epochs with no improvement after which training will be stopped.. Defaults to 5.\n",
    "            min_delta (float, optional): Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute \n",
    "                change of less than min_delta, will count as no improvement. Defaults to 0.01.\n",
    "        \"\"\"\n",
    "        self.patience = patience\n",
    "        self.min_delta = min_delta\n",
    "        self.best_metric = -1\n",
    "        self.counter = 0\n",
    "        \n",
    "    def __call__(self, metric):\n",
    "        if metric >= self.best_metric + self.min_delta:\n",
    "            # update best metric\n",
    "            self.best_metric = metric\n",
    "            # reset counter \n",
    "            self.counter = 0\n",
    "        else: \n",
    "            self.counter += 1\n",
    "            \n",
    "    @property\n",
    "    def early_stop(self):\n",
    "        return self.counter >= self.patience\n"
   ],
   "outputs": [],
   "execution_count": 17
  },
  {
   "cell_type": "code",
   "metadata": {
    "jupyter": {
     "is_executing": true
    },
    "ExecuteTime": {
     "start_time": "2025-01-17T12:40:27.127223Z"
    }
   },
   "source": [
    "# 训练\n",
    "def training(\n",
    "    model, \n",
    "    train_loader, \n",
    "    val_loader, \n",
    "    epoch, \n",
    "    loss_fct, \n",
    "    optimizer, \n",
    "    tensorboard_callback=None,\n",
    "    save_ckpt_callback=None,\n",
    "    early_stop_callback=None,\n",
    "    eval_step=500,\n",
    "    ):\n",
    "    record_dict = {\n",
    "        \"train\": [],\n",
    "        \"val\": []\n",
    "    }\n",
    "    \n",
    "    global_step = 0\n",
    "    model.train()\n",
    "    with tqdm(total=epoch * len(train_loader)) as pbar:\n",
    "        for epoch_id in range(epoch):\n",
    "            # training\n",
    "            for datas, labels in train_loader:\n",
    "                datas = datas.to(device)\n",
    "                labels = labels.to(device)\n",
    "                # 梯度清空\n",
    "                optimizer.zero_grad()\n",
    "                # 模型前向计算\n",
    "                logits = model(datas)\n",
    "                # 计算损失\n",
    "                loss = loss_fct(logits, labels)\n",
    "                # 梯度回传\n",
    "                loss.backward()\n",
    "                # 调整优化器，包括学习率的变动等\n",
    "                optimizer.step()\n",
    "                preds = logits.argmax(axis=-1)\n",
    "            \n",
    "                acc = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())    \n",
    "                loss = loss.cpu().item()\n",
    "                # record\n",
    "                \n",
    "                record_dict[\"train\"].append({\n",
    "                    \"loss\": loss, \"acc\": acc, \"step\": global_step\n",
    "                })\n",
    "                \n",
    "                # evaluating\n",
    "                if global_step % eval_step == 0:\n",
    "                    model.eval()\n",
    "                    val_loss, val_acc = evaluating(model, val_loader, loss_fct)\n",
    "                    record_dict[\"val\"].append({\n",
    "                        \"loss\": val_loss, \"acc\": val_acc, \"step\": global_step\n",
    "                    })\n",
    "                    model.train()\n",
    "                    \n",
    "                    # 1. 使用 tensorboard 可视化\n",
    "                    if tensorboard_callback is not None:\n",
    "                        tensorboard_callback(\n",
    "                            global_step, \n",
    "                            loss=loss, val_loss=val_loss,\n",
    "                            acc=acc, val_acc=val_acc,\n",
    "                            lr=optimizer.param_groups[0][\"lr\"],\n",
    "                            )\n",
    "                    \n",
    "                    # 2. 保存模型权重 save model checkpoint\n",
    "                    if save_ckpt_callback is not None:\n",
    "                        save_ckpt_callback(global_step, model.state_dict(), metric=val_acc)\n",
    "\n",
    "                    # 3. 早停 Early Stop\n",
    "                    if early_stop_callback is not None:\n",
    "                        early_stop_callback(val_acc)\n",
    "                        if early_stop_callback.early_stop:\n",
    "                            print(f\"Early stop at epoch {epoch_id} / global_step {global_step}\")\n",
    "                            return record_dict\n",
    "                    \n",
    "                # udate step\n",
    "                global_step += 1\n",
    "                pbar.update(1)\n",
    "                pbar.set_postfix({\"epoch\": epoch_id})\n",
    "        \n",
    "    return record_dict\n",
    "        \n",
    "\n",
    "epoch = 100\n",
    "\n",
    "model = NeuralNetwork(layers_num=10)\n",
    "\n",
    "# 1. 定义损失函数 采用交叉熵损失\n",
    "loss_fct = nn.CrossEntropyLoss()\n",
    "# 2. 定义优化器 采用SGD\n",
    "# Optimizers specified in the torch.optim package\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "# 1. tensorboard 可视化\n",
    "tensorboard_callback = TensorBoardCallback(\"runs/dropout\")\n",
    "tensorboard_callback.draw_model(model, [1, 28, 28])\n",
    "# 2. save best\n",
    "save_ckpt_callback = SaveCheckpointsCallback(\"checkpoints/dropout\", save_best_only=True)\n",
    "# 3. early stop\n",
    "early_stop_callback = EarlyStopCallback(patience=10, min_delta=0.001)\n",
    "\n",
    "model = model.to(device)\n",
    "record = training(\n",
    "    model, \n",
    "    train_loader, \n",
    "    val_loader, \n",
    "    epoch, \n",
    "    loss_fct, \n",
    "    optimizer, \n",
    "    tensorboard_callback=tensorboard_callback,\n",
    "    save_ckpt_callback=save_ckpt_callback,\n",
    "    early_stop_callback=early_stop_callback,\n",
    "    eval_step=len(train_loader)\n",
    "    )"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "  0%|          | 0/375000 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "2856f7aaa91241379aadcc009c4f96be"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {
    "jupyter": {
     "is_executing": true
    }
   },
   "source": [
    "#画线要注意的是损失是不一定在零到1之间的\n",
    "def plot_learning_curves(record_dict, sample_step=500):\n",
    "    # build DataFrame\n",
    "    train_df = pd.DataFrame(record_dict[\"train\"]).set_index(\"step\").iloc[::sample_step]\n",
    "    val_df = pd.DataFrame(record_dict[\"val\"]).set_index(\"step\")\n",
    "\n",
    "    # plot\n",
    "    fig_num = len(train_df.columns)\n",
    "    fig, axs = plt.subplots(1, fig_num, figsize=(6 * fig_num, 5))\n",
    "    for idx, item in enumerate(train_df.columns):    \n",
    "        axs[idx].plot(train_df.index, train_df[item], label=f\"train_{item}\")\n",
    "        axs[idx].plot(val_df.index, val_df[item], label=f\"val_{item}\")\n",
    "        axs[idx].grid()\n",
    "        axs[idx].legend()\n",
    "        axs[idx].set_xlabel(\"step\")\n",
    "    \n",
    "    plt.show()\n",
    "\n",
    "plot_learning_curves(record, sample_step=10000)  #横坐标是 steps"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {
    "jupyter": {
     "is_executing": true
    }
   },
   "source": [
    "# dataload for evaluating\n",
    "#推理时alphadropout的p值是0，通过加dropout，一定程度解决了过拟合\n",
    "# load checkpoints\n",
    "model.load_state_dict(torch.load(\"checkpoints/dropout/best.ckpt\", weights_only=True,map_location=\"cpu\"))\n",
    "\n",
    "model.eval()\n",
    "loss, acc = evaluating(model, val_loader, loss_fct)\n",
    "print(f\"loss:     {loss:.4f}\\naccuracy: {acc:.4f}\")"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
